# forward-compat boilerplate
from __future__ import absolute_import
from __future__ import with_statement
__metaclass__ = type

import os
import sqlalchemy as sql
from sqlalchemy import orm

# MetaData holding the tables owned by the upgrade machinery itself.
METADATA = sql.MetaData()
# Environment variable used as a process-wide "upgrade in progress" flag
# (read and set by upgrading()).
ENV_KEY = 'FENTON_UPGRADING'

# One row per applied schema upgrade; the highest `version` is taken as the
# current schema version by get_version().
SCHEMA_HISTORY =  sql.Table(
    'SCHEMA_HISTORY',
    METADATA,
    sql.Column('version', sql.Integer(), primary_key=True, nullable=False),
    sql.Column('author',  sql.String(), nullable=False),
    # NOTE(review): no client- or server-side default is declared for
    # `stamp`, so inserts must supply it explicitly unless the live table
    # has its own DB default.
    sql.Column('stamp',   sql.DateTime(timezone=False), nullable=False),
)


def initialize(app):
    '''Configure database access for `app` and return a session factory.

    Engine settings are read from ``app.config`` under the prefix
    ``<app.config['fenton.db']>.db.``. ``app.config['upgrade.path']`` is
    recorded on ``app.upgrade_path``. Unless an upgrade is already running
    (see upgrading()), an error is logged if an upgrade script is pending.

    Returns an ``orm.sessionmaker`` bound to the engine with autoflush off;
    the engine is also exposed as its ``bind`` attribute.
    '''
    from fenton import logging
    # Database sessions should use UTC. Assigning through os.environ calls
    # putenv() *and* keeps os.environ in sync, unlike a bare os.putenv().
    os.environ['PGTZ'] = 'UTC'
    name = app.config['fenton.db'] + '.db.'
    engine = sql.engine_from_config(app.config, name)
    app.upgrade_path = app.config['upgrade.path']
    if not upgrading():
        con = engine.connect()
        try:
            if find_upgrade(con, app.upgrade_path):
                logging.log.error('upgrade pending')
        finally:
            # Release the connection even if the version lookup or script
            # loading fails.
            con.close()
    DB = orm.sessionmaker(bind=engine, autoflush=False)
    DB.bind = engine
    return DB


# command
def upgrade_status(cx):
    '0 => no upgrade; 99 => online; 100 => offline'
    # Exit-code style summary of the pending upgrade (if any): a falsy
    # result from find_upgrade means nothing to do; otherwise the script's
    # `online` flag selects between the two "pending" codes.
    pending = find_upgrade(cx.db, cx.app.upgrade_path)
    if not pending:
        return 0
    if pending.online:
        return 99
    return 100


# command
def upgrade(cx, commit=False):
    up = find_upgrade(cx.db, cx.app.upgrade_path)
    if not up:
        print 'Nothing to do'
        return 1
    cx = UpgradeContext(cx)
    try:
        with cx:
            do_upgrade(cx, up)
            print 'Upgraded to version', up.version
            if not commit:
                raise NoCommit
    except NoCommit:
        print 'ROLLBACK'
    else:
        print 'COMMIT'


# decorator
def post_upgrade(f):
    '''Decorator: register `f` to be called with the upgrade context after
    every upgrade script (see do_upgrade); returns `f` unchanged.'''
    hooks = UpgradeContext.post_upgraders
    hooks.append(f)
    return f


def upgrading(set=None):
    '''Query or raise the process-wide "upgrade in progress" flag.

    With a truthy `set`, store the flag in os.environ and return None.
    Otherwise return whether the flag is present and non-empty. The flag
    cannot be cleared through this function. (`set` shadows the builtin but
    is kept for existing keyword callers.)
    '''
    if not set:
        return bool(os.environ.get(ENV_KEY))
    os.environ[ENV_KEY] = 'true'


def get_version(con):
    '''Return the highest version recorded in SCHEMA_HISTORY, or 0 if none.'''
    col = SCHEMA_HISTORY.c.version
    newest_first = sql.select([col]).order_by(col.desc()).limit(1)
    latest = con.execute(newest_first).scalar()
    return latest if latest else 0


def get_script(v, p):
    '''Wrap upgrade file `p` for version `v` in the matching script class.

    A ``.sql`` extension (case-insensitive) selects SqlScript; any other
    file is treated as Python and becomes a PyScript.
    '''
    suffix = os.path.splitext(p)[1]
    if suffix[1:].lower() == 'sql':
        return SqlScript(v, p)
    return PyScript(v, p)


def do_upgrade(cx, upgrade):
    for f in [upgrade] + cx.post_upgraders:
        print f
        f(cx)
        cx.flush()

    q = SCHEMA_HISTORY.insert().values(version=upgrade.version,
                                       author=cx.user.username)
    cx.execute(q)
    cx.flush()


def _quote_ident(name):
    '''Return `name` as a double-quoted SQL identifier, doubling any
    embedded double quotes.'''
    return '"%s"' % name.replace('"', '""')


def grant(cx, privs, user):
    '''Grant `privs` on every table, view and sequence in schema "public"
    to role `user`.

    `privs` is interpolated verbatim (e.g. ``'SELECT, INSERT'``), so it must
    come from a trusted caller. Object and role names are emitted as quoted
    identifiers, so names containing double quotes cannot break out of the
    statement.
    '''
    run = cx.bind.execute
    rows = run('''
        SELECT relname
        FROM pg_catalog.pg_class c,
             pg_catalog.pg_namespace n
        WHERE c.relkind IN ('r', 'v', 'S')
            AND n.oid = c.relnamespace
            AND n.nspname = 'public' ''')
    names = [row[0] for row in rows]
    role = _quote_ident(user)
    for name in names:
        run('GRANT %s ON %s TO %s ' % (privs, _quote_ident(name), role))


def find_upgrade(con, path, _memo=[]):

    if _memo:
        return _memo[1]

    _memo.append(None)
    OP = os.path
    upgrades = {}
    pending = None
    version = get_version(con) + 1

    if not OP.exists(path):
        print 'Directory %s does not exist' % path

    elif not OP.isdir(path):
        print 'Not a directory:', path

    v = None
    for d, _, files in os.walk(path, followlinks=True):
        for n in files:
            p = OP.join(d, n)
            if not OP.isfile(p):
                continue
            v = OP.splitext(n)[0]
            try:
                v = int(v)
            except ValueError:
                if v == 'next':
                    v = version
                else:
                    continue
            if v < version:
                continue
            if v in upgrades:
                print 'Duplicate upgrade:', p
                continue
            upgrades[v] = p

    try:
        pending = upgrades.pop(version)
    except KeyError:
        pass
    else:
        pending = get_script(version, pending)
    if upgrades:
        print 'Skipping unexpected upgrades:'
        for v, p in sorted(upgrades.iteritems()):
            print p

    _memo.append(pending)
    return _memo[1]


class NoCommit(Exception):
    '''Raised by upgrade() when committing was not requested, so that the
    upgrade's context is exited via an exception instead of committing.'''


class Upgrade:
    '''Base class for one schema-upgrade step loaded from a file.

    Subclasses implement ``execute(cx)`` and may override ``init()`` for
    load-time setup. Instances are callable: ``step(cx)`` runs
    ``step.execute(cx)``.
    '''

    # Reported by upgrade_status(): truthy => 99 ("online"), falsy => 100
    # ("offline"). SqlScript does not define its own `online`, so without
    # this default upgrade_status() raised AttributeError for SQL upgrades.
    online = False

    def __init__(self, version, path):
        self.version = version
        self.path = path
        self.init()

    def __call__(self, cx):
        return self.execute(cx)

    def __repr__(self):
        return '<Version %d:%s>' % (self.version, self.path)

    def init(self):
        '''Hook run at the end of __init__ (version/path already set).'''
        pass

    def execute(self, cx):
        '''Apply this upgrade using context `cx`; subclasses must override.'''
        raise NotImplementedError(
            '%s must implement execute()' % type(self).__name__)


class PyScript(Upgrade):
    online = property(lambda x:x.module.online)
    def init(self):
        src = open(self.path, 'rb').read()
        code = compile(src, self.path, 'exec')
        import imp
        self.module = mod = imp.new_module('version-%d' % self.version)
        self.module.online = False
        mod.__file__ = self.path
        exec code in mod.__dict__

    def execute(self, cx):
        self.module.upgrade(cx)


class SqlScript(Upgrade):
    '''An upgrade written as a plain SQL file, read and run at execute time.'''

    def execute(self, cx):
        # Upgrade stores the file location as `path`; the previous
        # `self.filename` does not exist and raised AttributeError. A
        # distinct local name also avoids shadowing the module-level `sql`.
        with open(self.path) as f:
            script = f.read()
        # '%' is doubled because the text is later used as a format string:
        # UpgradeContext.runmany echoes it via `q % {}` and it is presumably
        # also subject to DB-API pyformat interpolation.
        cx.run(script.replace('%', '%%'))


class UpgradeContext:
    '''Wraps a command context for the duration of a schema upgrade.

    On construction it reconfigures the application's session factory
    (``cx.app.db``) to use the engine configured under the
    ``<app.config['upgrade.db']>.db.`` prefix with autoflush on. It then
    takes a connection from ``cx.db``, which becomes ``bind`` for
    statements issued through this object. Attributes not defined here are
    looked up on the wrapped context, and ``with`` delegates to the wrapped
    context's __enter__/__exit__.
    '''

    # Callables registered via @post_upgrade and run by do_upgrade() after
    # each upgrade script. Shared by all instances on purpose.
    post_upgraders = []

    def __init__(self, cx):
        self.__cx = cx
        name = cx.app.config['upgrade.db'] + '.db.'
        engine = sql.engine_from_config(cx.app.config, name)
        # NOTE: this reconfigures the app-wide session factory, not just
        # this context.
        cx.app.db.configure(bind=engine, autoflush=True)
        self.context = cx
        self.bind = cx.bind = cx.db.connection()
        self.db = self.bind.db = cx.db
        self.flush = cx.db.flush

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard the slot
        # holding the wrapped context: when it is unset (instance built via
        # __new__, copy, pickle) looking it up here would recurse forever.
        if name == '_UpgradeContext__cx':
            raise AttributeError(name)
        return getattr(self.__cx, name)

    def __enter__(self):
        return self.__cx.__enter__()

    def __exit__(self, *errors):
        return self.__cx.__exit__(*errors)

    def execute(self, q, **kw):
        '''Flush pending ORM changes, then execute `q` on the upgrade
        connection and return its result.'''
        self.flush()
        return self.bind.execute(q, **kw)

    def runmany(self, *qq, **kw):
        '''Execute each query in order, echoing textual ones (formatted with
        `kw`) to stdout; return the last result, or None if none ran.'''
        r = None
        for q in qq:
            if isinstance(q, basestring):
                print(q % (kw or {}))
                print('')
            r = self.execute(q, **kw)
        return r

    def run(self, qq, **kw):
        '''Run a single query, a list/tuple of queries, or a SQL script
        string (split into statements at ";" followed by a newline).'''
        if isinstance(qq, basestring):
            qq = [q for q in qq.split(';\n') if q.strip()]
        elif not isinstance(qq, (list, tuple)):
            # A tuple used to be wrapped as one "query" and fail on execute.
            qq = [qq]
        return self.runmany(*qq, **kw)

    def create(self, *classes):
        '''Create the tables of the given mapped classes; with no classes,
        pass tables=None so METADATA.create_all creates all of its tables.'''
        ts = [c.__table__ for c in classes]
        METADATA.create_all(bind=self.bind, tables=ts or None)

    def drop_columns(self, *colnames):
        '''Drop columns named "table.column" (with CASCADE). A single
        argument containing newlines is split on whitespace into names.'''
        if len(colnames) == 1 and '\n' in colnames[0]:
            colnames = colnames[0].split()
        for n in colnames:
            self.run('ALTER TABLE %s DROP %s CASCADE' % tuple(n.split('.', 1)))

    def drop_tables(self, *tablenames):
        '''Drop tables (with CASCADE) and delete their metainfo rows. A
        single argument containing newlines is split on whitespace.'''
        if len(tablenames) == 1 and '\n' in tablenames[0]:
            tablenames = tablenames[0].split()
        for t in tablenames:
            self.run('DROP TABLE %s CASCADE' % t)
            self.delete_metainfo(t)

    def insert_metainfo(self, *classes):
        '''Insert metainfo for each class via fenton.data.'''
        from fenton import data
        for c in classes:
            data._insert_metainfo(self, c)

    def delete_metainfo(self, *tablenames):
        '''Delete MetaObjectClass rows whose tablename matches each name.'''
        from fenton import data
        C = data.MetaObjectClass
        q = C.__table__.delete()
        for t in tablenames:
            self.run(q.where(C.tablename == t))




